#!/usr/bin/env python3
"""
Phase 5b: GE clearing on 140-country sample
Phase 5c: Full projections on 140-country baseline model
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

FOLLOWUP_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
PROJECT_DIR = FOLLOWUP_DIR.parent
sys.path.insert(0, str(FOLLOWUP_DIR))
sys.path.insert(1, str(PROJECT_DIR))

from src.model import PanelGLS
from src.macro import filter_eba_sample
from src.scenarios import generate_projection_table

OUTPUT_DIR = FOLLOWUP_DIR / "output" / "tables"
PROCESSED_DIR = FOLLOWUP_DIR / "data" / "processed"

# ─── Load data ───────────────────────────────────────────────────────────
# Three inputs are read: the follow-up panel, the demographic polynomial
# terms (kept in the parent project's processed folder), and the saved
# 140-country baseline regression coefficients.
ORIG_PROCESSED = PROJECT_DIR / "data" / "processed"
panel = pd.read_csv(PROCESSED_DIR / "full_panel.csv")
polys = pd.read_csv(ORIG_PROCESSED / "demographic_polynomials.csv")

# Load 140-country baseline coefficients
coeffs = pd.read_csv(OUTPUT_DIR / "regression_baseline_demo_plus_eba_140.csv")

# The coefficient table is exposed two ways: as parallel name/value
# sequences (file order) and as a name -> value lookup.
feature_names = coeffs['variable'].tolist()
beta_array = np.array(coeffs['coefficient'])
coeff_map = dict(zip(coeffs['variable'], coeffs['coefficient']))

# Demographic terms used throughout; a KeyError here means the coefficient
# file does not contain one of them.
z_names = ['Z_1', 'Z_2', 'Z_3']
z_betas = {name: coeff_map[name] for name in z_names}
print(f"Baseline Z coefficients: {z_betas}")

# Create a mock model object with beta and feature_names
class MockModel:
    """Bare container standing in for a fitted model.

    Attributes:
        beta: coefficient values, stored exactly as passed in.
        feature_names: coefficient labels, stored exactly as passed in.
    """

    def __init__(self, beta, feature_names):
        # No validation or copying: both arguments are kept by reference.
        self.beta, self.feature_names = beta, feature_names

model = MockModel(beta=beta_array, feature_names=feature_names)

# ═══════════════════════════════════════════════════════════════════════════
# PHASE 5c: FULL PROJECTIONS ON 140-COUNTRY BASELINE
# ═══════════════════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print("PHASE 5c: DEMOGRAPHIC PROJECTIONS (140-COUNTRY BASELINE)")
print("=" * 70)

# Compute demographic contributions for all countries, all years.
# Every coefficient whose name starts with 'Z_' is treated as a demographic
# term; positions and names are collected together so they stay aligned.
z_indices = []
z_feat = []
for position, label in enumerate(feature_names):
    if label.startswith('Z_'):
        z_indices.append(position)
        z_feat.append(label)
z_coefs = beta_array[z_indices]

# Restrict to 1970-2100 (inclusive) and form the coefficient-weighted sum of
# the demographic columns, row by row.
proj_all = polys.loc[polys['year'].between(1970, 2100)].copy()
proj_all['demo_contribution'] = sum(
    coef * proj_all[column] for coef, column in zip(z_coefs, z_feat)
)

n_proj_countries = proj_all['iso3'].nunique()
first_proj_year = proj_all['year'].min()
last_proj_year = proj_all['year'].max()
print(f"Projections computed for {n_proj_countries} countries, "
      f"{first_proj_year}-{last_proj_year}")

# Filter to estimation sample countries: rows with an observed current
# account in 1986-2024 (inclusive), passed through the project's sample filter.
has_ca = panel['ca_gdp'].notna()
in_est_window = panel['year'].between(1986, 2024)
est = panel[has_ca & in_est_window].copy()

full_sample = filter_eba_sample(est, extended=True, expansion=True)
est_countries = full_sample['iso3'].unique()

keep_rows = proj_all['iso3'].isin(est_countries)
proj_est = proj_all[keep_rows].copy()
print(f"Filtered to {proj_est['iso3'].nunique()} estimation countries")

# Build profiles dict (like compute_country_profiles but without full model):
# one frame per country, keyed by iso3 in sorted order, holding the
# identifier columns, the demographic contribution and the Z terms.
all_countries = sorted(proj_est['iso3'].unique())
profile_columns = ['iso3', 'year', 'demo_contribution'] + z_feat
profiles = {
    iso3: country_rows[profile_columns].copy()
    for iso3, country_rows in proj_est.groupby('iso3', sort=True)
}

# Generate projection table for key countries
focus_countries = [
    # Original paper focus
    'CHN', 'IND', 'IDN', 'JPN', 'USA',
    'DEU', 'BRA', 'NGA', 'ZAF', 'KOR',
    'GBR', 'AUS', 'MEX', 'TUR', 'SAU',
    # Expansion highlights
    'IRN', 'VNM', 'BGD', 'GTM', 'DOM', 'ROU', 'UKR', 'DZA',
    'SVN', 'KAZ', 'QAT', 'CRI', 'ECU', 'PHL', 'ARG', 'EGY',
    # New from 140 expansion
    'GEO', 'ARM', 'KHM', 'NPL',
    'BOL', 'JAM', 'ALB', 'MDA',
]

# Years that become the columns of the summary table.
years_show = [2000, 2010, 2020, 2025, 2030, 2040, 2050, 2060]

proj_table = generate_projection_table(
    profiles,
    countries=focus_countries,
    years_to_show=years_show,
)

# Persist the table, then echo it to the console with two decimals.
proj_table.to_csv(OUTPUT_DIR / "projection_table_140.csv", index=False)
n_table_rows = len(proj_table)
print(f"\nProjection table ({n_table_rows} countries):")
print(proj_table.to_string(index=False, float_format='%.2f'))

# Compute inflection points (when demo_contribution crosses zero or reaches
# extremum). Only the 2020+ part of each profile is examined, and countries
# with fewer than five such rows are skipped.
inflection_rows = []
for iso3 in all_countries:
    profile = profiles[iso3]
    future = profile.loc[profile['year'] >= 2020].sort_values('year')
    if len(future) < 5:
        continue

    yrs = future['year'].values
    dc = future['demo_contribution'].values

    # Positions of the rows closest to 2025 and 2050 (ties -> earliest row).
    at_2025 = np.abs(yrs - 2025).argmin()
    at_2050 = np.abs(yrs - 2050).argmin()

    # Positions of the largest and smallest contribution over the horizon.
    at_peak = dc.argmax()
    at_trough = dc.argmin()

    # First adjacent pair of opposite sign; the crossing year is placed by
    # linear interpolation between the two observations.
    zero_year = None
    for y_left, y_right, v_left, v_right in zip(yrs[:-1], yrs[1:], dc[:-1], dc[1:]):
        if v_left * v_right < 0:
            share = abs(v_left) / (abs(v_left) + abs(v_right))
            zero_year = y_left + share * (y_right - y_left)
            break

    inflection_rows.append({
        'iso3': iso3,
        'dc_2025': dc[at_2025],
        'dc_peak': dc[at_peak],
        'peak_year': yrs[at_peak],
        'dc_trough': dc[at_trough],
        'trough_year': yrs[at_trough],
        'zero_crossing': zero_year,
        # Change in the contribution between the 2025 and 2050 rows.
        'swing_2025_2050': dc[at_2050] - dc[at_2025],
    })

# Most negative swing first, most positive last.
inflection_df = pd.DataFrame(inflection_rows).sort_values('swing_2025_2050')
inflection_df.to_csv(OUTPUT_DIR / "inflection_points_140.csv", index=False)

# inflection_df is sorted ascending by swing, so the head holds the most
# negative swings and the tail the most positive ones.
for swing_heading, swing_rows in (
    ("\n\nLargest negative swings (2025-2050):", inflection_df.head(15)),
    ("\nLargest positive swings (2025-2050):", inflection_df.tail(15)),
):
    print(swing_heading)
    for _, r in swing_rows.iterrows():
        if pd.notna(r['zero_crossing']):
            # Interpolated crossing year is truncated to a whole year.
            zc = f"crosses zero ~{int(r['zero_crossing'])}"
        else:
            zc = "no zero crossing"
        print(f"  {r['iso3']:>4}: {r['swing_2025_2050']:+.2f}pp  (2025: {r['dc_2025']:+.2f}, {zc})")

# Group summary
# Classify countries by the sign and size of their 2025 contribution, using
# a ±0.5pp band (bounds inclusive) as "near zero".
dc_now = inflection_df['dc_2025']
pos_current = inflection_df[dc_now > 0.5]
neg_current = inflection_df[dc_now < -0.5]
near_zero = inflection_df[dc_now.between(-0.5, 0.5)]

n_surplus = len(pos_current)
n_deficit = len(neg_current)
n_near_zero = len(near_zero)

print("\n\nCountry groups by current demographic position (2025):")
print(f"  Demographic surplus (>+0.5pp): {n_surplus} countries")
print(f"  Demographic deficit (<-0.5pp): {n_deficit} countries")
print(f"  Near zero (±0.5pp):            {n_near_zero} countries")

# Save full projections (every year kept for the estimation countries).
proj_est.to_csv(OUTPUT_DIR / "demographic_contributions_140.csv", index=False)
n_proj_rows = len(proj_est)
print(f"\nFull projections saved: {n_proj_rows:,} rows")


# ═══════════════════════════════════════════════════════════════════════════
# PHASE 5b: GE CLEARING ON 140-COUNTRY SAMPLE
# ═══════════════════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print("PHASE 5b: GE CAPITAL MARKET CLEARING (140-COUNTRY SAMPLE)")
print("=" * 70)

# S1: Estimate demographics → bond yields
est_full = filter_eba_sample(est, extended=True, expansion=True)

# Regressors: the demographic terms plus controls; any name that is not a
# column of the sample is silently left out.
s1_candidates = z_names + ['fiscal_bal_gdp', 'expected_growth', 'nfa_gdp_lag', 'log_rel_opw']
s1_vars = [name for name in s1_candidates if name in est_full.columns]

# Keep only rows where the dependent variable and every regressor are present.
s1_df = est_full.dropna(subset=['real_bond_10y_diff', *s1_vars])

print(f"\nS1 estimation sample: {len(s1_df)} obs, {s1_df['iso3'].nunique()} countries")

s1_y = s1_df['real_bond_10y_diff'].values
s1_X = s1_df[s1_vars].values
s1 = PanelGLS()
s1.fit(s1_y, s1_X, s1_df['iso3'].values, s1_df['year'].values)

# Collect the yield coefficients on the demographic terms and report them
# with conventional significance stars.
s1_z_betas = {}
for i, v in enumerate(s1_vars):
    if v not in z_names:
        continue
    coef = s1.beta[i]
    pval = s1.pvalues[i]
    s1_z_betas[v] = coef
    if pval < 0.001:
        sig = '***'
    elif pval < 0.01:
        sig = '**'
    elif pval < 0.05:
        sig = '*'
    else:
        sig = ''
    print(f"  S1 {v}: {coef:.4f} (p={pval:.4f}) {sig}")

# Rate-to-CA coefficient from Model 3b
delta_rate = 0.127
print(f"  Rate-to-CA semi-elasticity (δ): {delta_rate}")

# GDP weights
# Shares of nominal USD GDP in the single most recent year that has any
# complete (iso3, year, ngdp_usd) row; countries without a row in that year
# receive no weight.
gdp_data = panel[['iso3', 'year', 'ngdp_usd']].dropna()
latest_year = gdp_data['year'].max()
is_latest = gdp_data['year'] == latest_year
gdp_weights = gdp_data.loc[is_latest, ['iso3', 'ngdp_usd']].copy()

total_gdp = gdp_weights['ngdp_usd'].sum()
gdp_weights['weight'] = gdp_weights['ngdp_usd'] / total_gdp
weight_map = dict(zip(gdp_weights['iso3'], gdp_weights['weight']))
print(f"  GDP weights from {latest_year}: {len(weight_map)} countries")

# Five largest weights, formatted to three decimals for display.
top5 = sorted(weight_map.items(), key=lambda pair: pair[1], reverse=True)[:5]
top5_display = [(code, f'{share:.3f}') for code, share in top5]
print(f"  Top 5: {top5_display}")

# Project PE CAs and yield effects
# Five-year grid 2000..2060; only estimation countries that have a GDP
# weight take part.
proj_years = list(range(2000, 2065, 5))
countries_with_weight = [code for code in all_countries if code in weight_map]

rows = []
for iso3 in countries_with_weight:
    country_polys = polys.loc[polys['iso3'] == iso3]
    for year in proj_years:
        year_rows = country_polys.loc[country_polys['year'] == year]
        if year_rows.empty:
            continue

        # Demographic terms for this country-year (first matching row).
        z_now = {zv: year_rows[zv].values[0] for zv in z_names}

        # Partial-equilibrium CA effect uses the baseline coefficients; the
        # yield effect uses the S1 coefficients, with 0 for any term that
        # was not in the S1 regression.
        demo_ca = sum(z_betas[zv] * z_now[zv] for zv in z_names)
        demo_yield = sum(s1_z_betas.get(zv, 0) * z_now[zv] for zv in z_names)

        rows.append(dict(
            iso3=iso3,
            year=year,
            weight=weight_map[iso3],
            demo_ca_pe=demo_ca,
            demo_yield_effect=demo_yield,
        ))

ge_proj = pd.DataFrame(rows)
n_country_years = len(ge_proj)
n_ge_countries = ge_proj['iso3'].nunique()
print(f"\n  GE projection matrix: {n_country_years} country-years, "
      f"{n_ge_countries} countries")

# Compute clearing rate
# Cap (in pp) on the absolute size of the world rate adjustment.
max_delta_r = 2.0

clearing_rows = []
for year in proj_years:
    yr_data = ge_proj.loc[ge_proj['year'] == year]
    if yr_data.empty:
        continue

    # Re-normalise the GDP weights over the countries present this year.
    w_sum = yr_data['weight'].sum()
    w_norm = yr_data['weight'] / w_sum

    pe_imbalance = (w_norm * yr_data['demo_ca_pe']).sum()
    global_yield_effect = (w_norm * yr_data['demo_yield_effect']).sum()

    # The uncapped adjustment is the one that makes residual_imbalance
    # below exactly zero; clipping it leaves part of the imbalance uncleared.
    delta_r_uncapped = pe_imbalance / delta_rate + global_yield_effect
    delta_r_world = np.clip(delta_r_uncapped, -max_delta_r, max_delta_r)
    residual_imbalance = pe_imbalance + delta_rate * (global_yield_effect - delta_r_world)

    # Denominator is floored at 0.001 so a near-zero imbalance cannot blow
    # up the percentage.
    denom = max(abs(pe_imbalance), 0.001)
    pct_cleared = (1 - abs(residual_imbalance) / denom) * 100

    clearing_rows.append({
        'year': year,
        'pe_global_imbalance': pe_imbalance,
        'global_yield_effect': global_yield_effect,
        'delta_r_uncapped': delta_r_uncapped,
        'delta_r_world': delta_r_world,
        'residual_imbalance': residual_imbalance,
        'pct_cleared': pct_cleared,
        'n_countries': len(yr_data),
        'weight_coverage': w_sum,
    })

clearing = pd.DataFrame(clearing_rows)

print("\n  Clearing rate adjustment by year:")
print(f"  {'Year':>6} {'PE Imbal':>10} {'Δr*(pp)':>10} {'Δr* uncap':>10} {'%cleared':>10} {'N ctry':>8}")
for rec in clearing.itertuples(index=False):
    print(f"  {int(rec.year):>6} {rec.pe_global_imbalance:>10.3f} "
          f"{rec.delta_r_world:>10.3f} {rec.delta_r_uncapped:>10.3f} "
          f"{rec.pct_cleared:>10.1f} {int(rec.n_countries):>8}")

# GE-adjusted projections
# Attach each year's world rate adjustment to every country-year, then form
# the rate channel with and without that adjustment.
delta_r_map = dict(zip(clearing['year'], clearing['delta_r_world']))
world_adjustment = ge_proj['year'].map(delta_r_map)
own_yield_effect = ge_proj['demo_yield_effect']

ge_proj['delta_r_world'] = world_adjustment
ge_proj['rate_channel_pe'] = delta_rate * own_yield_effect
ge_proj['rate_channel_ge'] = delta_rate * (own_yield_effect - world_adjustment)

# GE contribution = PE contribution plus the GE rate channel; the
# adjustment column records the difference between the two.
ge_proj['demo_ca_ge'] = ge_proj['demo_ca_pe'] + ge_proj['rate_channel_ge']
ge_proj['ge_adjustment'] = ge_proj['demo_ca_ge'] - ge_proj['demo_ca_pe']

# Print focus country comparison
# For each focus country that appears in ge_proj, three rows are printed:
# the PE contribution, the GE contribution, and their difference (Δ).
focus_ge = ['JPN', 'DEU', 'USA', 'KOR', 'CHN', 'IND', 'IDN', 'NGA', 'BRA',
            'IRN', 'VNM', 'BGD', 'GTM', 'SAU']

# One list drives both the header and the data rows so they cannot drift
# apart; every cell is 8 wide and cells are separated by a single space.
ge_show_years = [2020, 2030, 2040, 2050, 2060]
year_header = ' '.join(f"{y:>8}" for y in ge_show_years)

print("\n  PE vs GE demographic CA contributions:")
print(f"  {'Country':>8} {'':>4} {year_header}")
for iso3 in focus_ge:
    cdf = ge_proj[ge_proj['iso3'] == iso3]
    if cdf.empty:
        continue

    # year -> value lookups; a year with no row shows up as nan in all
    # three lines (a missing year must not be reported as a zero adjustment).
    pe_vals = dict(zip(cdf['year'].astype(int), cdf['demo_ca_pe']))
    ge_vals = dict(zip(cdf['year'].astype(int), cdf['demo_ca_ge']))

    pe_str = ' '.join(f"{pe_vals.get(y, np.nan):>8.2f}" for y in ge_show_years)
    ge_str = ' '.join(f"{ge_vals.get(y, np.nan):>8.2f}" for y in ge_show_years)
    adj_str = ' '.join(
        f"{ge_vals.get(y, np.nan) - pe_vals.get(y, np.nan):>+8.2f}"
        for y in ge_show_years
    )

    print(f"  {iso3:>8} {'PE':>4} {pe_str}")
    print(f"  {'':>8} {'GE':>4} {ge_str}")
    print(f"  {'':>8} {'Δ':>4} {adj_str}")

# Save
# Country-year projections and the per-year clearing rates are written as
# separate CSVs, both without the index column.
for out_frame, out_name in (
    (ge_proj, "ge_clearing_projections_140.csv"),
    (clearing, "ge_clearing_rates_140.csv"),
):
    out_frame.to_csv(OUTPUT_DIR / out_name, index=False)
print("\nSaved GE clearing results")


# ═══════════════════════════════════════════════════════════════════════════
# COMPARISON WITH ORIGINAL 69-COUNTRY GE CLEARING
# ═══════════════════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print("COMPARISON: 69-COUNTRY vs 140-COUNTRY GE CLEARING")
print("=" * 70)

# The original clearing rates live in the parent project's output folder;
# the comparison is skipped with a message when that file is absent.
orig_clearing = None
orig_path = PROJECT_DIR / "output" / "tables" / "ge_clearing_rates.csv"
if not orig_path.exists():
    print("  Original GE clearing results not found for comparison")
else:
    orig_clearing = pd.read_csv(orig_path)
    print(f"\n  {'Year':>6} {'Orig Δr*':>10} {'140c Δr*':>10} {'Diff':>10}")
    for year in proj_years:
        orig_row = orig_clearing.loc[orig_clearing['year'] == year]
        new_row = clearing.loc[clearing['year'] == year]
        # Only years present in both result sets are printed.
        if orig_row.empty or new_row.empty:
            continue
        orig_r = orig_row['delta_r_world'].iloc[0]
        new_r = new_row['delta_r_world'].iloc[0]
        diff_r = new_r - orig_r
        print(f"  {year:>6} {orig_r:>10.3f} {new_r:>10.3f} {diff_r:>+10.3f}")

print("\nDone.")
